1 package org.apache.lucene.search.join;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 import java.io.IOException;
21 import java.util.Arrays;
22 import java.util.HashMap;
23 import java.util.LinkedList;
24 import java.util.Map;
25 import java.util.Queue;
26
27 import org.apache.lucene.index.IndexWriter;
28 import org.apache.lucene.index.LeafReaderContext;
29 import org.apache.lucene.search.Collector;
30 import org.apache.lucene.search.FieldComparator;
31 import org.apache.lucene.search.FieldValueHitQueue;
32 import org.apache.lucene.search.LeafCollector;
33 import org.apache.lucene.search.LeafFieldComparator;
34 import org.apache.lucene.search.Query;
35 import org.apache.lucene.search.ScoreCachingWrappingScorer;
36 import org.apache.lucene.search.Scorer;
37 import org.apache.lucene.search.Scorer.ChildScorer;
38 import org.apache.lucene.search.Sort;
39 import org.apache.lucene.search.TopDocs;
40 import org.apache.lucene.search.TopDocsCollector;
41 import org.apache.lucene.search.TopFieldCollector;
42 import org.apache.lucene.search.TopScoreDocCollector;
43 import org.apache.lucene.search.grouping.GroupDocs;
44 import org.apache.lucene.search.grouping.TopGroups;
45 import org.apache.lucene.util.ArrayUtil;
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96 public class ToParentBlockJoinCollector implements Collector {
97
98 private final Sort sort;
99
100
101
102 private final Map<Query,Integer> joinQueryID = new HashMap<>();
103 private final int numParentHits;
104 private final FieldValueHitQueue<OneGroup> queue;
105 private final FieldComparator<?>[] comparators;
106 private final boolean trackMaxScore;
107 private final boolean trackScores;
108
109 private ToParentBlockJoinQuery.BlockJoinScorer[] joinScorers = new ToParentBlockJoinQuery.BlockJoinScorer[0];
110 private boolean queueFull;
111
112 private OneGroup bottom;
113 private int totalHitCount;
114 private float maxScore = Float.NaN;
115
116
117
118
119
/**
 * Creates a collector that keeps the top parent hits according to {@code sort}.
 *
 * @param sort sort used to rank parent hits; its fields back the hit queue
 * @param numParentHits size of the hit queue, i.e. the number of parent hits retained
 * @param trackScores if true, the score of each retained parent (and of its child hits) is recorded
 * @param trackMaxScore if true, the maximum score over all collected parents is maintained
 * @throws IOException if the hit queue cannot be created
 */
public ToParentBlockJoinCollector(Sort sort, int numParentHits, boolean trackScores, boolean trackMaxScore) throws IOException {
  this.sort = sort;
  this.trackMaxScore = trackMaxScore;
  if (trackMaxScore) {
    // Seed the running maximum with negative infinity. Float.MIN_VALUE is the
    // smallest *positive* float, so using it as the seed would report 1.4E-45
    // as the max score whenever every collected score is zero or negative.
    maxScore = Float.NEGATIVE_INFINITY;
  }

  this.trackScores = trackScores;
  this.numParentHits = numParentHits;
  queue = FieldValueHitQueue.create(sort.getSort(), numParentHits);
  comparators = queue.getComparators();
}
134
135 private static final class OneGroup extends FieldValueHitQueue.Entry {
136 public OneGroup(int comparatorSlot, int parentDoc, float parentScore, int numJoins, boolean doScores) {
137 super(comparatorSlot, parentDoc, parentScore);
138
139 docs = new int[numJoins][];
140 for(int joinID=0;joinID<numJoins;joinID++) {
141 docs[joinID] = new int[5];
142 }
143 if (doScores) {
144 scores = new float[numJoins][];
145 for(int joinID=0;joinID<numJoins;joinID++) {
146 scores[joinID] = new float[5];
147 }
148 }
149 counts = new int[numJoins];
150 }
151 LeafReaderContext readerContext;
152 int[][] docs;
153 float[][] scores;
154 int[] counts;
155 }
156
157 @Override
158 public LeafCollector getLeafCollector(final LeafReaderContext context)
159 throws IOException {
160 final LeafFieldComparator[] comparators = queue.getComparators(context);
161 final int[] reverseMul = queue.getReverseMul();
162 final int docBase = context.docBase;
163 return new LeafCollector() {
164
165 private Scorer scorer;
166
167 @Override
168 public void setScorer(Scorer scorer) throws IOException {
169
170
171
172
173 if (scorer instanceof ScoreCachingWrappingScorer == false) {
174 scorer = new ScoreCachingWrappingScorer(scorer);
175 }
176 this.scorer = scorer;
177 for (LeafFieldComparator comparator : comparators) {
178 comparator.setScorer(scorer);
179 }
180 Arrays.fill(joinScorers, null);
181
182 Queue<Scorer> queue = new LinkedList<>();
183
184 queue.add(scorer);
185 while ((scorer = queue.poll()) != null) {
186
187 if (scorer instanceof ToParentBlockJoinQuery.BlockJoinScorer) {
188 enroll((ToParentBlockJoinQuery) scorer.getWeight().getQuery(), (ToParentBlockJoinQuery.BlockJoinScorer) scorer);
189 }
190
191 for (ChildScorer sub : scorer.getChildren()) {
192
193 queue.add(sub.child);
194 }
195 }
196 }
197
198 @Override
199 public void collect(int parentDoc) throws IOException {
200
201 totalHitCount++;
202
203 float score = Float.NaN;
204
205 if (trackMaxScore) {
206 score = scorer.score();
207 maxScore = Math.max(maxScore, score);
208 }
209
210
211
212
213
214 if (queueFull) {
215
216
217 int c = 0;
218 for (int i = 0; i < comparators.length; ++i) {
219 c = reverseMul[i] * comparators[i].compareBottom(parentDoc);
220 if (c != 0) {
221 break;
222 }
223 }
224 if (c <= 0) {
225
226
227 return;
228 }
229
230
231
232
233 for (LeafFieldComparator comparator : comparators) {
234 comparator.copy(bottom.slot, parentDoc);
235 }
236 if (!trackMaxScore && trackScores) {
237 score = scorer.score();
238 }
239 bottom.doc = docBase + parentDoc;
240 bottom.readerContext = context;
241 bottom.score = score;
242 copyGroups(bottom);
243 bottom = queue.updateTop();
244
245 for (LeafFieldComparator comparator : comparators) {
246 comparator.setBottom(bottom.slot);
247 }
248 } else {
249
250 final int comparatorSlot = totalHitCount - 1;
251
252
253 for (LeafFieldComparator comparator : comparators) {
254 comparator.copy(comparatorSlot, parentDoc);
255 }
256
257 if (!trackMaxScore && trackScores) {
258 score = scorer.score();
259 }
260 final OneGroup og = new OneGroup(comparatorSlot, docBase+parentDoc, score, joinScorers.length, trackScores);
261 og.readerContext = context;
262 copyGroups(og);
263 bottom = queue.add(og);
264 queueFull = totalHitCount == numParentHits;
265 if (queueFull) {
266
267 for (LeafFieldComparator comparator : comparators) {
268 comparator.setBottom(bottom.slot);
269 }
270 }
271 }
272 }
273
274
275 private void copyGroups(OneGroup og) {
276
277
278
279 final int numSubScorers = joinScorers.length;
280 if (og.docs.length < numSubScorers) {
281
282
283
284 og.docs = ArrayUtil.grow(og.docs);
285 }
286 if (og.counts.length < numSubScorers) {
287 og.counts = ArrayUtil.grow(og.counts);
288 }
289 if (trackScores && og.scores.length < numSubScorers) {
290 og.scores = ArrayUtil.grow(og.scores);
291 }
292
293
294 for(int scorerIDX = 0;scorerIDX < numSubScorers;scorerIDX++) {
295 final ToParentBlockJoinQuery.BlockJoinScorer joinScorer = joinScorers[scorerIDX];
296
297 if (joinScorer != null && docBase + joinScorer.getParentDoc() == og.doc) {
298 og.counts[scorerIDX] = joinScorer.getChildCount();
299
300 og.docs[scorerIDX] = joinScorer.swapChildDocs(og.docs[scorerIDX]);
301 assert og.docs[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.docs[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
302
303
304
305
306
307
308 if (trackScores) {
309
310 og.scores[scorerIDX] = joinScorer.swapChildScores(og.scores[scorerIDX]);
311 assert og.scores[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.scores[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
312 }
313 } else {
314 og.counts[scorerIDX] = 0;
315 }
316 }
317 }
318 };
319 }
320
321 private void enroll(ToParentBlockJoinQuery query, ToParentBlockJoinQuery.BlockJoinScorer scorer) {
322 scorer.trackPendingChildHits();
323 final Integer slot = joinQueryID.get(query);
324 if (slot == null) {
325 joinQueryID.put(query, joinScorers.length);
326
327 final ToParentBlockJoinQuery.BlockJoinScorer[] newArray = new ToParentBlockJoinQuery.BlockJoinScorer[1+joinScorers.length];
328 System.arraycopy(joinScorers, 0, newArray, 0, joinScorers.length);
329 joinScorers = newArray;
330 joinScorers[joinScorers.length-1] = scorer;
331 } else {
332 joinScorers[slot] = scorer;
333 }
334 }
335
336 private OneGroup[] sortedGroups;
337
338 private void sortQueue() {
339 sortedGroups = new OneGroup[queue.size()];
340 for(int downTo=queue.size()-1;downTo>=0;downTo--) {
341 sortedGroups[downTo] = queue.pop();
342 }
343 }
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361 public TopGroups<Integer> getTopGroups(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
362 int maxDocsPerGroup, int withinGroupOffset, boolean fillSortFields)
363 throws IOException {
364
365 final Integer _slot = joinQueryID.get(query);
366 if (_slot == null && totalHitCount == 0) {
367 return null;
368 }
369
370 if (sortedGroups == null) {
371 if (offset >= queue.size()) {
372 return null;
373 }
374 sortQueue();
375 } else if (offset > sortedGroups.length) {
376 return null;
377 }
378
379 return accumulateGroups(_slot == null ? -1 : _slot.intValue(), offset, maxDocsPerGroup, withinGroupOffset, withinGroupSort, fillSortFields);
380 }
381
382
383
384
385
386
387
388
389
390
391
392
393
394 @SuppressWarnings({"unchecked","rawtypes"})
395 private TopGroups<Integer> accumulateGroups(int slot, int offset, int maxDocsPerGroup,
396 int withinGroupOffset, Sort withinGroupSort, boolean fillSortFields) throws IOException {
397 final GroupDocs<Integer>[] groups = new GroupDocs[sortedGroups.length - offset];
398 final FakeScorer fakeScorer = new FakeScorer();
399
400 int totalGroupedHitCount = 0;
401
402
403 for(int groupIDX=offset;groupIDX<sortedGroups.length;groupIDX++) {
404 final OneGroup og = sortedGroups[groupIDX];
405 final int numChildDocs;
406 if (slot == -1 || slot >= og.counts.length) {
407 numChildDocs = 0;
408 } else {
409 numChildDocs = og.counts[slot];
410 }
411
412
413 final int numDocsInGroup = Math.max(1, Math.min(numChildDocs, maxDocsPerGroup));
414
415
416
417
418 final TopDocsCollector<?> collector;
419 if (withinGroupSort == null) {
420
421
422 if (!trackScores) {
423 throw new IllegalArgumentException("cannot sort by relevance within group: trackScores=false");
424 }
425 collector = TopScoreDocCollector.create(numDocsInGroup);
426 } else {
427
428 collector = TopFieldCollector.create(withinGroupSort, numDocsInGroup, fillSortFields, trackScores, trackMaxScore);
429 }
430
431 LeafCollector leafCollector = collector.getLeafCollector(og.readerContext);
432 leafCollector.setScorer(fakeScorer);
433 for(int docIDX=0;docIDX<numChildDocs;docIDX++) {
434
435 final int doc = og.docs[slot][docIDX];
436 fakeScorer.doc = doc;
437 if (trackScores) {
438 fakeScorer.score = og.scores[slot][docIDX];
439 }
440 leafCollector.collect(doc);
441 }
442 totalGroupedHitCount += numChildDocs;
443
444 final Object[] groupSortValues;
445
446 if (fillSortFields) {
447 groupSortValues = new Object[comparators.length];
448 for(int sortFieldIDX=0;sortFieldIDX<comparators.length;sortFieldIDX++) {
449 groupSortValues[sortFieldIDX] = comparators[sortFieldIDX].value(og.slot);
450 }
451 } else {
452 groupSortValues = null;
453 }
454
455 final TopDocs topDocs = collector.topDocs(withinGroupOffset, numDocsInGroup);
456
457 groups[groupIDX-offset] = new GroupDocs<>(og.score,
458 topDocs.getMaxScore(),
459 numChildDocs,
460 topDocs.scoreDocs,
461 og.doc,
462 groupSortValues);
463 }
464
465 return new TopGroups<>(new TopGroups<>(sort.getSort(),
466 withinGroupSort == null ? null : withinGroupSort.getSort(),
467 0, totalGroupedHitCount, groups, maxScore),
468 totalHitCount);
469 }
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485 public TopGroups<Integer> getTopGroupsWithAllChildDocs(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
486 int withinGroupOffset, boolean fillSortFields)
487 throws IOException {
488
489 return getTopGroups(query, withinGroupSort, offset, Integer.MAX_VALUE, withinGroupOffset, fillSortFields);
490 }
491
492
493
494
495
496
497
498 public float getMaxScore() {
499 return maxScore;
500 }
501
502 @Override
503 public boolean needsScores() {
504
505
506 return true;
507 }
508 }